{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chemprop.data.datasets import (\n",
    "    CuikmolmakerDataset, MoleculeDataset, ReactionDataset, MulticomponentDataset\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make a dataset you first need a list of [datapoints](./datapoints.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from chemprop.data import LazyMoleculeDatapoint, MoleculeDatapoint, ReactionDatapoint\n",
    "\n",
    "ys = np.random.rand(2, 1)\n",
    "\n",
    "smis = [\"C\", \"CC\"]\n",
    "mol_datapoints = [MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]\n",
    "\n",
    "lazy_mol_datapoints = [LazyMoleculeDatapoint(\"C\" * i, y=[i]) for i in range(1, 20)]\n",
    "\n",
    "rxn_smis = [\"[H:2][O:1][H:3]>>[H:2][O:1].[H:3]\", \"[H:2][S:1][H:3]>>[H:2][S:1].[H:3]\"]\n",
    "rxn_datapoints = [\n",
    "    ReactionDatapoint.from_smi(rxn_smi, y, keep_h=True) for rxn_smi, y in zip(rxn_smis, ys)\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Molecule Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MoleculeDataset`s are made from a list of `MoleculeDatapoint`s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MoleculeDataset(data=[MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7f7a6b6ed5b0>, y=array([0.23384385]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='C', V_f=None, E_f=None, V_d=None), MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7f7a6b6ed690>, y=array([0.74433064]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CC', V_f=None, E_f=None, V_d=None)], featurizer=SimpleMoleculeMolGraphFeaturizer(atom_featurizer=<chemprop.featurizers.atom.MultiHotAtomFeaturizer object at 0x7f7a6b52c290>, bond_featurizer=<chemprop.featurizers.bond.MultiHotBondFeaturizer object at 0x7f7a6b52c150>))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MoleculeDataset(mol_datapoints)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset properties"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The properties of datapoints are collated in a dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.23384385]\n",
      " [0.74433064]]\n",
      "['C', 'CC']\n"
     ]
    }
   ],
   "source": [
    "dataset = MoleculeDataset(mol_datapoints)\n",
    "print(dataset.Y)\n",
    "print(dataset.names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Datasets return a `Datum` when indexed. A `Datum` contains a `MolGraph` (see the [molgraph featurizer notebook](../featurizers/molgraph_molecule_featurizer.ipynb)), the extra atom and datapoint level descriptors, the target(s), the weights, and masks for bounded loss functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Datum(mg=MolGraph(V=array([[0.     , 0.     , 0.     , 0.     , 0.     , 1.     , 0.     ,\n",
       "        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,\n",
       "        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,\n",
       "        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,\n",
       "        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,\n",
       "        0.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,\n",
       "        1.     , 0.     , 0.     , 0.     , 0.     , 0.     , 0.     ,\n",
       "        1.     , 0.     , 1.     , 0.     , 0.     , 0.     , 0.     ,\n",
       "        0.     , 0.     , 0.     , 0.     , 1.     , 0.     , 0.     ,\n",
       "        0.     , 0.     , 0.     , 1.     , 0.     , 0.     , 0.     ,\n",
       "        0.     , 0.12011]], dtype=float32), E=array([], shape=(0, 14), dtype=float64), edge_index=array([], shape=(2, 0), dtype=int64), rev_edge_index=array([], dtype=int64)), V_d=None, x_d=None, y=array([0.23384385]), weight=1.0, lt_mask=None, gt_mask=None)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Caching"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `MolGraph`s are generated as needed by default. For small to medium dataset (exact sizes not yet benchmarked), it is more efficient to generate and cache the molgraphs when the dataset is created. \n",
    "\n",
    "If the cache needs to be recreated, set the cache to True again. To clear the cache, set it to False. \n",
    "\n",
    "Note we recommend [scaling](../scaling.ipynb) additional atom and bond features before setting the cache, as scaling them after caching will require the cache to be recreated, which is done automatically. \n",
    "\n",
    "To featurize the graphs in parallel when caching, use the `n_workers` argument when creating the dataset. Note that this may cause hangs on Windows and MacOS. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "if sys.platform not in [\"win32\", \"darwin\"]:\n",
    "    dataset = MoleculeDataset(mol_datapoints, n_workers=3)\n",
    "else:\n",
    "    dataset = MoleculeDataset(mol_datapoints)\n",
    "\n",
    "dataset.cache = True  # Generate the molgraphs and cache them\n",
    "dataset.cache = True  # Recreate the cache\n",
    "dataset.cache = False  # Clear the cache\n",
    "\n",
    "dataset.cache = True  # Cache created with unscaled extra bond features\n",
    "dataset.normalize_inputs(key=\"E_f\")  # Cache recreated automatically with scaled extra bond features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CuikmolmakerDataset (available with `cuik-molmaker` only)\n",
    "This dataset constructs and featurizes a batch of molecules at once instead of one at a time using `cuik-molmaker`. `CuikmolmakerDataset` implements `__getitems__` instead of `__getitem__` enabling batched dataset featurization and access. This method returns a `CuikBatchedDatum` which contains the same information as a `Datum`, except that the graph information is returned as a series of `Tensor`s instead of a `MolGraph` and each molecule's information is batched together."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuik-molmaker available: True\n"
     ]
    }
   ],
   "source": [
    "from chemprop.utils.utils import is_cuikmolmaker_available\n",
    "print(f\"cuik-molmaker available: {is_cuikmolmaker_available()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CuikBatchedDatum(atom_feats=tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1201],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1201],\n",
      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,\n",
      "         0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1201]]), bond_feats=tensor([[0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
      "        [0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]]), edge_index=tensor([[1, 2],\n",
      "        [2, 1]]), rev_edge_index=tensor([1, 0]), batch=tensor([0, 1, 1]), V_d=None, X_d=array([None, None], dtype=object), Y=array([[1.],\n",
      "       [2.]]), weights=array([1., 1.]), lt_mask=array([None, None], dtype=object), gt_mask=array([None, None], dtype=object))\n"
     ]
    }
   ],
   "source": [
    "if is_cuikmolmaker_available():\n",
    "    cuik_dataset = CuikmolmakerDataset(lazy_mol_datapoints)\n",
    "    print(cuik_dataset.__getitems__([0, 1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Datasets with custom featurizers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Datasets use a molgraph featurizer to create the `MolGraphs`s from the `rdkit.Chem.Mol` objects in datapoints. A basic `SimpleMoleculeMolGraphFeaturizer` is the default featurizer for `MoleculeDataset`s. If you are using a [custom molgraph featurizer](../featurizers/molgraph_molecule_featurizer.ipynb), pass it as an argument when creating the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MoleculeDataset(data=[MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7f7a6b6ed5b0>, y=array([0.23384385]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='C', V_f=None, E_f=None, V_d=None), MoleculeDatapoint(mol=<rdkit.Chem.rdchem.Mol object at 0x7f7a6b6ed690>, y=array([0.74433064]), weight=1.0, gt_mask=None, lt_mask=None, x_d=None, x_phase=None, name='CC', V_f=None, E_f=None, V_d=None)], featurizer=SimpleMoleculeMolGraphFeaturizer(atom_featurizer=<chemprop.featurizers.atom.MultiHotAtomFeaturizer object at 0x7f7a6b538a50>, bond_featurizer=<chemprop.featurizers.bond.MultiHotBondFeaturizer object at 0x7f7a6b538f50>))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from chemprop.featurizers import SimpleMoleculeMolGraphFeaturizer, MultiHotAtomFeaturizer\n",
    "\n",
    "mol_featurizer = SimpleMoleculeMolGraphFeaturizer(atom_featurizer=MultiHotAtomFeaturizer.v1())\n",
    "MoleculeDataset(mol_datapoints, featurizer=mol_featurizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reaction Datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reaction datasets are the same as molecule datasets, except they are made from a list of `ReactionDatapoint`s and `CondensedGraphOfReactionFeaturizer` is the default featurizer. [CGRs](../featurizers/molgraph_reaction_featurizer.ipynb) are also `MolGraph`s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CondensedGraphOfReactionFeaturizer(atom_featurizer=<chemprop.featurizers.atom.MultiHotAtomFeaturizer object at 0x7f7a6b53ab10>, bond_featurizer=<chemprop.featurizers.bond.MultiHotBondFeaturizer object at 0x7f7a6b53a8d0>)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ReactionDataset(rxn_datapoints).featurizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multicomponent datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`MulticomponentDataset` is for datasets whose target values depend on multiple components. It is composed of parallel `MoleculeDataset`s and `ReactionDataset`s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<chemprop.data.datasets.MulticomponentDataset at 0x7f7a6b53bb90>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mol_dataset = MoleculeDataset(mol_datapoints)\n",
    "rxn_dataset = ReactionDataset(rxn_datapoints)\n",
    "\n",
    "# e.g. reaction in solvent\n",
    "multi_dataset = MulticomponentDataset(datasets=[mol_dataset, rxn_dataset])\n",
    "\n",
    "# e.g. solubility\n",
    "MulticomponentDataset(datasets=[mol_dataset, mol_dataset])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A `MulticomponentDataset` collates dataset properties (e.g. SMILES) of each dataset. It does not collate datapoint level properties like target values and extra datapoint descriptors. Chemprop models automatically take those from **the first dataset** in datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('C', ('[O:1]([H:2])[H:3]', '[H:3].[O:1][H:2]')),\n",
       " ('CC', ('[S:1]([H:2])[H:3]', '[H:3].[S:1][H:2]'))]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "multi_dataset.smiles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.23384385],\n",
       "       [0.74433064]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "multi_dataset.datasets[0].Y"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
